"""
This example shows how to train a Cross-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).

In this example we use a knowledge distillation setup. In https://arxiv.org/abs/2010.02666, Sebastian Hofstätter et al. trained
an ensemble of large Transformer models for the MS MARCO dataset, combining the scores from a BERT-base, BERT-large, and ALBERT-large model.

We use the logit scores from the ensemble to train a smaller model. We found that the MiniLM model gives the best performance while
offering the highest speed.

The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve for example 100 passages
for a given query, for example with Elasticsearch, and pass the query+retrieved_passage to the CrossEncoder
for scoring. You then sort the results according to the output of the CrossEncoder.

This gives a significant boost compared to out-of-the-box Elasticsearch / BM25 ranking.

Running this script:
python train_cross-encoder-v2.py
"""

from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import InputExample
import logging
from datetime import datetime
import gzip
import os
import tarfile
import torch

#### Configure logging so that timestamped INFO messages are printed while the script runs
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],  # handler class imported from sentence_transformers
)
#### /logging configuration


# Basic training configuration: the transformer checkpoint to fine-tune, batch size and epoch count
model_name = "microsoft/MiniLM-L12-H384-uncased"
train_batch_size = 32
num_epochs = 1

# The output folder name encodes the model (with "/" replaced by "-") and the start time of this run,
# so that repeated runs do not write into the same directory
model_tag = model_name.replace("/", "-")
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = "output/training_ms-marco_cross-encoder-v2-" + model_tag + "-" + run_timestamp


# A single output label combined with an Identity activation makes the model return its raw logits.
# The training labels further below are the teacher's scores, which the model is fitted to directly.
model = CrossEncoder(
    model_name,
    num_labels=1,
    max_length=512,
    default_activation_function=torch.nn.Identity(),
)


### Now we read the MS Marco dataset
# All downloaded and extracted files are kept in this local folder
data_folder = "msmarco-data"
os.makedirs(data_folder, exist_ok=True)


#### Read the corpus file, which contains all the passages. Store them in the corpus dict (pid -> passage)
corpus = {}
collection_filepath = os.path.join(data_folder, "collection.tsv")
if not os.path.exists(collection_filepath):
    # Download the archive only if it is not already present, then unpack it into data_folder
    tar_filepath = os.path.join(data_folder, "collection.tar.gz")
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get("https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz", tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        # The archive was fetched from the network. Where the interpreter supports extraction
        # filters, use the "data" filter so that archive members cannot be written outside of
        # data_folder; this also avoids the DeprecationWarning that Python 3.12/3.13 emit for an
        # unfiltered extractall(). Older interpreters fall back to the previous behaviour.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=data_folder, filter="data")
        else:
            tar.extractall(path=data_folder)

# Each line is expected to have the form: <pid> TAB <passage>
with open(collection_filepath, "r", encoding="utf8") as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        corpus[pid] = passage


### Read the train queries, store them in the queries dict (qid -> query text)
queries = {}
queries_filepath = os.path.join(data_folder, "queries.train.tsv")
if not os.path.exists(queries_filepath):
    # Download the archive only if it is not already present, then unpack it into data_folder
    tar_filepath = os.path.join(data_folder, "queries.tar.gz")
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get("https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz", tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        # The archive was fetched from the network. Where the interpreter supports extraction
        # filters, use the "data" filter so that archive members cannot be written outside of
        # data_folder; this also avoids the DeprecationWarning that Python 3.12/3.13 emit for an
        # unfiltered extractall(). Older interpreters fall back to the previous behaviour.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=data_folder, filter="data")
        else:
            tar.extractall(path=data_folder)


# Each line is expected to have the form: <qid> TAB <query>
with open(queries_filepath, "r", encoding="utf8") as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query


### Build the dev set that is used for evaluation during training
train_samples = []
dev_samples = {}

# Evaluation uses the first 200 distinct queries encountered in the (pre-shuffled) train-eval file.
# Each of them has at least one relevant passage and at most 200 irrelevant (negative) passages
num_dev_queries = 200
num_max_dev_negatives = 200

# msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz and msmarco-qidpidtriples.rnd-shuf.train.tsv.gz are randomly
# shuffled versions of qidpidtriples.train.full.2.tsv.gz from the MS Marco website.
# The train-eval split holds 500 random queries that can be used for evaluation during training
train_eval_filepath = os.path.join(data_folder, "msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz")
if not os.path.exists(train_eval_filepath):
    logging.info("Download " + os.path.basename(train_eval_filepath))
    util.http_get("https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz", train_eval_filepath)

# Each line holds a whitespace-separated triple: query id, positive passage id, negative passage id
with gzip.open(train_eval_filepath, "rt") as fIn:
    for line in fIn:
        qid, pos_id, neg_id = line.strip().split()

        sample = dev_samples.get(qid)
        if sample is None:
            # Unseen query: admit it only while the dev set is not full, otherwise ignore the triple
            if len(dev_samples) >= num_dev_queries:
                continue
            sample = {"query": queries[qid], "positive": set(), "negative": set()}
            dev_samples[qid] = sample

        # Sets de-duplicate passages that appear in several triples of the same query
        sample["positive"].add(corpus[pos_id])
        if len(sample["negative"]) < num_max_dev_negatives:
            sample["negative"].add(corpus[neg_id])

# Query ids reserved for evaluation; they must not be used as training data
dev_qids = set(dev_samples)

# Read our training file
# As input examples, we provide the (query, passage) pair together with the logits score from the teacher ensemble
teacher_logits_filepath = os.path.join(data_folder, "bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv")
train_samples = []
if not os.path.exists(teacher_logits_filepath):
    # Announce the download like the other dataset downloads in this script do
    logging.info("Download " + os.path.basename(teacher_logits_filepath))
    util.http_get(
        "https://zenodo.org/record/4068216/files/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv?download=1",
        teacher_logits_filepath,
    )

# Each line holds five tab-separated fields: the teacher score of the positive passage, the teacher
# score of the negative passage, the query id and the two passage ids.
# The encoding is given explicitly (as for the other text files) instead of relying on the platform default
with open(teacher_logits_filepath, "r", encoding="utf8") as fIn:
    for line in fIn:
        pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")

        if qid in dev_qids:  # Skip queries in our dev dataset
            continue

        # One regression example per passage; the label is the teacher score of that (query, passage) pair
        query_text = queries[qid]
        train_samples.append(InputExample(texts=[query_text, corpus[pid1]], label=float(pos_score)))
        train_samples.append(InputExample(texts=[query_text, corpus[pid2]], label=float(neg_score)))

# Shuffled mini-batches over the training pairs; an incomplete final batch is dropped
train_dataloader = DataLoader(train_samples, batch_size=train_batch_size, shuffle=True, drop_last=True)

# Evaluator that is run on the held-out dev queries during training.
# It scores each dev query against its positive and negative passages and measures the quality of the
# resulting ranking (MRR@10 according to the CERerankingEvaluator documentation)
evaluator = CERerankingEvaluator(dev_samples, name="train-eval")

# Configure the training
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))
evaluation_steps = 5000
optimizer_params = {"lr": 7e-6}
# The labels are the teacher scores, so the student is trained with a mean-squared-error loss
regression_loss = torch.nn.MSELoss()


# Train the model
model.fit(
    train_dataloader=train_dataloader,
    loss_fct=regression_loss,
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=evaluation_steps,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    optimizer_params=optimizer_params,
    use_amp=True,
)

# Additionally store the final model state under a separate "-latest" path
model.save(model_save_path + "-latest")
